#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 19 15:54:38 2022

@author: qiguangyao
"""


#%%Lib
import copy
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns 
from scipy import asarray as ar,exp
from scipy.optimize import curve_fit
import math
import pingouin as pg
from sklearn import linear_model
from pylab import cos
import pandas as pd
import random
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm
from statsmodels.sandbox.stats.multicomp import multipletests # for multiple comparisons correction
from statsmodels.stats.multicomp import pairwise_tukeyhsd
# Echo the path of this script when it is run.
print(f"__file Output: {__file__}")
#%%functions
import scipy.stats
def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean and a two-sided Student-t confidence interval.

    Parameters
    ----------
    data : array-like
        Sample values.
    confidence : float
        Coverage of the interval (default 0.95).

    Returns
    -------
    tuple
        ``(mean, lower, upper)`` where the half-width is
        ``sem * t.ppf((1 + confidence) / 2, n - 1)``.
    """
    sample = np.array(data) * 1.0
    count = len(sample)
    centre = np.mean(sample)
    standard_error = scipy.stats.sem(sample)
    half_width = standard_error * scipy.stats.t.ppf((1 + confidence) / 2., count - 1)
    return centre, centre - half_width, centre + half_width

def adjust_spines(ax, spines):
    """Keep only the named spines, push them outward and drop unused ticks.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes restyled in place.
    spines : container of str
        Names of spines to keep visible, e.g. ``['left', 'bottom']``.
    """
    for name, spine in ax.spines.items():
        if name not in spines:
            spine.set_color('none')  # hide this spine entirely
            continue
        spine.set_position(('outward', 10))  # offset kept spine by 10 points

    # y axis: ticks only if the left spine is kept
    if 'left' not in spines:
        ax.yaxis.set_ticks([])
    else:
        ax.yaxis.set_ticks_position('left')

    # x axis: ticks only if the bottom spine is kept
    if 'bottom' not in spines:
        ax.xaxis.set_ticks([])
    else:
        ax.xaxis.set_ticks_position('bottom')
        
def gaus(x,a,x0,sigma):
    """Gaussian curve whose integral over x equals ``a``.

    ``a / (sigma * sqrt(2*pi)) * exp(-(x - x0)**2 / (2 * sigma**2))``

    The normalising constant is ``1 / (sigma * sqrt(2*pi))`` (parenthesised
    explicitly), so ``a`` is the area under the curve, ``x0`` the centre and
    ``sigma`` the standard deviation.
    """
    norm = 1.0 / (sigma * np.sqrt(2 * np.pi))
    return a * norm * np.exp(-(x - x0) ** 2 / (2 * sigma ** 2))

def gaussian(X, amp, cen, wid):
    """Unnormalised Gaussian bump ``amp * exp(-(X - cen)**2 / wid)``.

    Note that ``wid`` divides the squared distance directly (it is not a
    standard deviation).
    """
    squared_offset = (X - cen) ** 2
    return amp * np.exp(-squared_offset / wid)

def getPossionPDF(mu,x):
    """Poisson probability mass ``P(X = k)`` for rate ``mu + 0.01``.

    Parameters
    ----------
    mu : float or 0-d array
        Expected count. 0.01 is added before use, so a zero rate still gives
        a proper distribution.
    x : number
        Observed count. It is clamped to [0, 170] and rounded to the nearest
        integer ``k``.

    Returns
    -------
    float
        ``exp(-rate) * rate**k / k!``. It is evaluated in log space via
        ``lgamma`` so that large rates cannot overflow. For a non-positive
        rate, the degenerate limit is returned: 1.0 for ``k == 0``, else 0.0.
    """
    # int() guarantees an integral count even when round() returns a NumPy float
    k = int(round(min(max(x, 0), 170)))
    rate = float(mu) + 0.01
    if rate <= 0:
        return 1.0 if k == 0 else 0.0
    return math.exp(k * math.log(rate) - rate - math.lgamma(k + 1))

#tuning curve fitting
def vonMisesFunction(x,b,a,u):
    """Rectified cosine tuning curve ``max(b + a*cos(x - u), 0)``.

    NOTE: a second ``def vonMisesFunction`` further down in this file rebinds
    the name, so at runtime that later definition is the one used.

    Returns a NumPy array (0-d for scalar ``x``) with negative values set to 0.
    """
    response = np.array(b + a * cos(x - u))
    response[response < 0] = 0  # rectify: firing rates cannot be negative
    return response

def getvonMisesParas(x,y):
    """Least-squares fit of ``vonMisesFunction`` to the data (x, y).

    x : hand position
    y : firing rate

    Returns the best-fit ``[b, a, u]``, starting from ``[1, 0, 1]``.
    NOTE: a later ``def getvonMisesParas`` in this file rebinds the name,
    so this definition is overridden at runtime.
    """
    fitted, _covariance = curve_fit(vonMisesFunction, x, y,
                                    p0=[1, 0, 1],  # initial [b, a, u]
                                    maxfev=500000)
    return fitted

def getExpParas(x,y):
    """Least-squares fit of ``expFunction`` (``a*exp(-b*x) + c``) to (x, y).

    x : hand position
    y : firing rate

    Returns the best-fit ``[a, b, c]``, starting from ``[1, 0, 1]``.
    NOTE: a later ``def getExpParas`` in this file rebinds the name, so this
    definition is overridden at runtime.
    """
    fitted, _covariance = curve_fit(expFunction, x, y,
                                    p0=[1, 0, 1],  # initial [a, b, c]
                                    maxfev=500000)
    return fitted

def expFunction(x, a, b, c):
    """Exponential decay with offset: ``a * exp(-b * x) + c``."""
    decay = np.exp(-b * x)
    return c + a * decay
#tuning curve fitting
def vonMisesFunction(x,b,a,u):
    """Rectified cosine tuning curve ``max(b + a*cos(x - u), 0)``.

    Despite the name, this is a cosine profile rather than a von Mises one.

    Parameters
    ----------
    x : scalar or array-like
        Angle(s) in radians. Lists and tuples are accepted.
    b : float
        Baseline offset.
    a : float
        Modulation amplitude.
    u : float
        Phase offset in radians (the peak location when ``a > 0``).

    Returns
    -------
    numpy.ndarray
        Same shape as ``x`` (0-d for scalar input), with negative values
        clipped to 0.
    """
    response = b + a * np.cos(np.asarray(x) - u)
    # np.asarray keeps the ndarray return type (0-d for scalars) so callers can .tolist()
    return np.asarray(np.maximum(response, 0))

def getvonMisesParas(x,y):
    """Least-squares fit of ``vonMisesFunction`` to the data (x, y).

    x : independent variable fed to ``cos(x - u)`` (angles, radians)
    y : observed response (firing rate)

    Returns the best-fit ``[b, a, u]``, starting from ``[1, 1, 0]``.
    """
    fitted, _covariance = curve_fit(
        vonMisesFunction, x, y,
        p0=[1, 1, 0],  # initial [b, a, u]
        maxfev=500000,
    )
    return fitted

def getExpParas(x,y):
    """Least-squares fit of ``expFunction`` (``a*exp(-b*x) + c``) to (x, y).

    x : hand position
    y : firing rate

    Returns the best-fit ``[a, b, c]``, starting from ``[1, 1, 1]``.
    """
    fitted, _covariance = curve_fit(
        expFunction, x, y,
        p0=[1, 1, 1],  # initial [a, b, c]
        maxfev=500000000,
    )
    return fitted
#%% ------------figure 3---------------- 
# Load all figure-3 inputs. The `with` block closes the file handle.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open('fig3Data.pickle','rb') as fig3File:
    fig3Data = pkl.load(fig3File)
#3C left panel
omga2PEVTarVPPMC = fig3Data['omga2PEVTarVPPMC']
omga2PEVTarVPArea5 = fig3Data['omga2PEVTarVPArea5']
omga2PEVTarPPMC = fig3Data['omga2PEVTarPPMC']
omga2PEVTarPArea5 = fig3Data['omga2PEVTarPArea5']
# Mean and SEM per area, ordered [PMC, Area5]. Both are NaN-aware: the SEM
# divides the NaN-ignoring std by the number of non-NaN samples, so it
# matches nanmean instead of becoming NaN (np.std) or shrinking (len).
omga2PEVTarVPMean = [np.nanmean(omga2PEVTarVPPMC), np.nanmean(omga2PEVTarVPArea5)]
omga2PEVTarVPSEM = [np.nanstd(v) / np.sqrt(np.count_nonzero(~np.isnan(v)))
                    for v in (omga2PEVTarVPPMC, omga2PEVTarVPArea5)]
omga2PEVTarPMean = [np.nanmean(omga2PEVTarPPMC), np.nanmean(omga2PEVTarPArea5)]
omga2PEVTarPSEM = [np.nanstd(v) / np.sqrt(np.count_nonzero(~np.isnan(v)))
                   for v in (omga2PEVTarPPMC, omga2PEVTarPArea5)]
# Between-area rank-sum tests: computed once, printed, then FDR-corrected.
# NOTE(review): the P comparison passes Area5 first while the VP comparison
# passes PMC first, so the printed statistics have opposite sign
# conventions (two-sided p-values are unaffected) - confirm intent.
ranksumTarVP = stats.ranksums(omga2PEVTarVPPMC,omga2PEVTarVPArea5)
ranksumTarP = stats.ranksums(omga2PEVTarPArea5,omga2PEVTarPPMC)
print(ranksumTarVP)
print(ranksumTarP)
pColl1 = [ranksumTarVP[1], ranksumTarP[1]]
correction_method = 'fdr_bh'
alpha = 0.05
rejectList_1, PValsMulTest_1 = multipletests(pColl1, method=correction_method, alpha=alpha)[:2]

# One-sample Wilcoxon signed-rank test of each group against zero.
for sample in (omga2PEVTarVPPMC, omga2PEVTarPPMC, omga2PEVTarVPArea5, omga2PEVTarPArea5):
    print(stats.wilcoxon(np.array(sample).reshape(len(sample))))


#3C right panel
lateCorporalHandPEVPMC = fig3Data['lateCorporalHandPEVPMC']
lateCorporalHandPEVArea5 = fig3Data['lateCorporalHandPEVArea5']
lateVisualHandPEVPMC = fig3Data['lateVisualHandPEVPMC']
lateVisualHandPEVArea5 = fig3Data['lateVisualHandPEVArea5']
# Mean and SEM per area, ordered [PMC, Area5]. Both are NaN-aware: the SEM
# divides the NaN-ignoring std by the number of non-NaN samples (not len()),
# so it stays consistent with nanmean when NaNs are present.
lateCorporalHandPEVMean = [np.nanmean(lateCorporalHandPEVPMC),np.nanmean(lateCorporalHandPEVArea5)]
lateCorporalHandPEVSEM = [np.nanstd(v) / np.sqrt(np.count_nonzero(~np.isnan(v)))
                          for v in (lateCorporalHandPEVPMC, lateCorporalHandPEVArea5)]
lateVisualHandPEVMean = [np.nanmean(lateVisualHandPEVPMC),np.nanmean(lateVisualHandPEVArea5)]
lateVisualHandPEVSEM = [np.nanstd(v) / np.sqrt(np.count_nonzero(~np.isnan(v)))
                        for v in (lateVisualHandPEVPMC, lateVisualHandPEVArea5)]
# Between-area rank-sum tests, computed once and reused for the correction.
ranksumCorporal = stats.ranksums(lateCorporalHandPEVPMC,lateCorporalHandPEVArea5)
ranksumVisual = stats.ranksums(lateVisualHandPEVPMC,lateVisualHandPEVArea5)
print(ranksumCorporal)
print(ranksumVisual)

# One-sample Wilcoxon signed-rank tests against zero (Area5 groups only).
for sample in (lateCorporalHandPEVArea5, lateVisualHandPEVArea5):
    print(stats.wilcoxon(np.array(sample).reshape(len(sample))))

pColl4 = [ranksumCorporal[1], ranksumVisual[1]]
correction_method = 'fdr_bh'
alpha = 0.05
rejectList_4, PValsMulTest_4 = multipletests(pColl4, method=correction_method, alpha=alpha)[:2]

#3D inputs, PMC
pMeanFRSmooTrajPMC, vpHandMeanFRSmooTrajPMC, vpcHandMeanFRSmooTrajPMC = (
    fig3Data['pMeanFRSmooTrajPMC'],
    fig3Data['vpHandMeanFRSmooTrajPMC'],
    fig3Data['vpcHandMeanFRSmooTrajPMC'],
)
(pEventIndexRasterPMC, vpHandEventIndexRasterPMC, vpcHandEventIndexRasterPMC,
 neuralEventReachSuccessRasterPMC, infoRasterPMC) = (
    fig3Data['pEventIndexRasterPMC'],
    fig3Data['vpHandEventIndexRasterPMC'],
    fig3Data['vpcHandEventIndexRasterPMC'],
    fig3Data['neuralEventReachSuccessRasterPMC'],
    fig3Data['infoRasterPMC'],
)

#3D inputs, Area5
pMeanFRSmooTrajArea5, vpHandMeanFRSmooTrajArea5, vpcHandMeanFRSmooTrajArea5 = (
    fig3Data['pMeanFRSmooTrajArea5'],
    fig3Data['vpHandMeanFRSmooTrajArea5'],
    fig3Data['vpcHandMeanFRSmooTrajArea5'],
)
(pEventIndexRasterArea5, vpHandEventIndexRasterArea5, vpcHandEventIndexRasterArea5,
 neuralEventReachSuccessRasterArea5, infoRasterArea5) = (
    fig3Data['pEventIndexRasterArea5'],
    fig3Data['vpHandEventIndexRasterArea5'],
    fig3Data['vpcHandEventIndexRasterArea5'],
    fig3Data['neuralEventReachSuccessRasterArea5'],
    fig3Data['infoRasterArea5'],
)

#3F
vpWeigDrifDispPMC, vpWeigDrifDispArea5 = (
    fig3Data['vpWeigDrifDispPMC'],
    fig3Data['vpWeigDrifDispArea5'],
)

#3H
pCommonExamParietal, vpWeighExamParietal = (
    fig3Data['pCommonExamParietal'],
    fig3Data['vpWeighExamParietal'],
)

#3G
popuPMCVPWeight, popuArea5VPWeight = (
    fig3Data['popuPMCVPWeight'],
    fig3Data['popuArea5VPWeight'],
)

#3G right: neuron counts
caulInfeNeurPMCNum, caulInfeNeurArea5Num = (
    fig3Data['caulInfeNeurPMCNum'],
    fig3Data['caulInfeNeurArea5Num'],
)

PMCNeurNum, Area5NeurNum = (
    fig3Data['PMCNeurNum'],
    fig3Data['Area5NeurNum'],
)
#%%fig3C-left boxplot
widt1 = .25
s1 = 25
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    x = np.arange(2)
    fig, ax = plt.subplots(figsize=(3.54/1.2,3.54/2))
    width = 0.4
    # One entry per box: (samples, x position, marker alpha, colour).
    # PMC sits on the integer positions and Area5 is offset right by `width`.
    # The first pair is the VP condition, the second the P condition.
    groups = [
        (omga2PEVTarVPPMC, x[0], .5, colors[3]),
        (omga2PEVTarVPArea5, x[0] + width, .2, colors[4]),
        (omga2PEVTarPPMC, x[1], .5, colors[3]),
        (omga2PEVTarPArea5, x[1] + width, .2, colors[4]),
    ]
    # Hollow markers for every sample, with no horizontal jitter. All points
    # are drawn before any box, so the artist order is points, then boxes.
    for data, pos, markerAlpha, colour in groups:
        plt.scatter([pos] * len(data), data,
                    alpha = markerAlpha,
                    s = s1,
                    color = [],
                    marker = 'o',
                    edgecolors = colour)
    # Filled, semi-transparent boxes without fliers, with a black median line.
    for data, pos, _, colour in groups:
        bplot = plt.boxplot([v for v in data], positions = [pos],
                            showfliers=False,
                            widths = widt1,
                            medianprops = {'color':'k'},
                            patch_artist=True)
        for patch in bplot['boxes']:
            patch.set(color=colour)
            patch.set(facecolor=colour)
            patch.set(alpha = .5)
    plt.xticks(rotation=0)
    plt.yticks(fontsize = 8)
    plt.xticks(fontsize = 8)
    ax.set_xlim([-.25,1.65])
    # Ticks are centred between each PMC/Area5 pair; tick labels stay numeric.
    ax.set_xticks(x + width/2)
    adjust_spines(ax, ['left', 'bottom'])
    ax.set_ylabel('ωPEV',fontsize = 8)
    fig.tight_layout()
    fileName = 'fig3C_lateVPPArms.pdf'
    plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig3C-right boxplot
widt1 = .25
s1 = 25
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    x = np.arange(2)
    fig, ax = plt.subplots(figsize=(3.54/1.2,3.54/2))
    width = 0.4
    # One entry per box: (samples, x position, marker alpha, colour).
    # PMC sits on the integer positions and Area5 is offset right by `width`.
    # The first pair is the corporal (proprioceptive) arm, the second the visual arm.
    groups = [
        (lateCorporalHandPEVPMC, x[0], .5, colors[3]),
        (lateCorporalHandPEVArea5, x[0] + width, .2, colors[4]),
        (lateVisualHandPEVPMC, x[1], .5, colors[3]),
        (lateVisualHandPEVArea5, x[1] + width, .2, colors[4]),
    ]
    # Hollow markers for every sample, with no horizontal jitter. All points
    # are drawn before any box, so the artist order is points, then boxes.
    for data, pos, markerAlpha, colour in groups:
        plt.scatter([pos] * len(data), data,
                    alpha = markerAlpha,
                    s = s1,
                    color = [],
                    marker = 'o',
                    edgecolors = colour)
    # Filled, semi-transparent boxes without fliers, with a black median line.
    for data, pos, _, colour in groups:
        bplot = plt.boxplot(data, positions = [pos],
                            showfliers=False,
                            widths = widt1,
                            medianprops = {'color':'k'},
                            patch_artist=True)
        for patch in bplot['boxes']:
            patch.set(color=colour)
            patch.set(facecolor=colour)
            patch.set(alpha = .5)
    ax.set_ylabel('ωPEV',fontsize = 8)
    plt.xticks(rotation=0)
    plt.xticks(fontsize = 8)
    plt.yticks(fontsize = 8)
    ax.set_xlim([-.25,1.65])
    # Ticks are centred between each PMC/Area5 pair; tick labels stay numeric.
    ax.set_xticks(x + width/2)
    adjust_spines(ax, ['left', 'bottom'])

    fig.tight_layout()
    fileName = 'fig3C_lateVPCArms.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%% fig3D raster
#Area5
key =['20180505S1SPK01a']#Area5
pEventIndex = copy.deepcopy(pEventIndexRasterArea5)
vpHandEventIndex = copy.deepcopy(vpHandEventIndexRasterArea5)
vpcHandEventIndex = copy.deepcopy(vpcHandEventIndexRasterArea5)
neuralEventReachSuccess = copy.deepcopy(neuralEventReachSuccessRasterArea5)
info = copy.deepcopy(infoRasterArea5)

disparitys = np.array([-45,-35,-20,-10,0,10,20,35,45])
disparitysEvent = np.array([45,-45,35,-35,20,-20,10,-10,0])
plotTrial = 1
# Size of each raster tick in data units (seconds x trial rows).
widthRast = 0.001
hightRast = 0.01


def _drawRasterRow(ax, eventTimes, row, colour, width, height):
    """Add a width x height rectangle to `ax`, centred on (t, row), for each t in `eventTimes`."""
    for t in eventTimes:
        ax.add_patch(plt.Rectangle((t - width/2, row - height/2),
                                   width, height, color=colour))


with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    colorsDisp = plt.cm.jet(np.linspace(0,1,9))
    # Each disparity is coloured by its rank in the ascending `disparitys`
    # list: -45 gets the first jet colour and 45 the last.
    dispColour = {d: colorsDisp[n] for n, d in enumerate(disparitys)}
    f, ax1 = plt.subplots(figsize=[3.54/3,3.54/2])

    # P trials, in black, at the bottom of the raster.
    for i in pEventIndex:
        _drawRasterRow(ax1, neuralEventReachSuccess[i,:], plotTrial, 'k', widthRast, hightRast)
        plotTrial += 1

    # VPC trials, grouped by disparity in `disparitysEvent` order
    # (column 6 of `info` is compared with the disparity value).
    for d in disparitysEvent:
        for i in vpcHandEventIndex:
            if info[i,6] == d:
                _drawRasterRow(ax1, neuralEventReachSuccess[i,:], plotTrial, dispColour[d], widthRast, hightRast)
                plotTrial += 1

    # VP trials, in gray, at the top of the raster.
    for i in vpHandEventIndex:
        _drawRasterRow(ax1, neuralEventReachSuccess[i,:], plotTrial, 'gray', widthRast, hightRast)
        plotTrial += 1

    # Vertical reference lines at -0.5 s and 0 s, spanning rows 0..plotTrial-1.
    ax1.plot([-0.5] * plotTrial, list(range(plotTrial)), '-', lw = 1/2, color = 'k')
    ax1.plot([0] * plotTrial, list(range(plotTrial)), '-', lw = 1/2, color = 'k')
    ax1.set_ylabel('Trial #')
    ax1.set_xlabel('Time (s)')
    plt.ylim(bottom = 0,top = plotTrial+1)
    ax1.set_xlim([-1,0.05])
    ax1.set_xticks(np.arange(-1,0.05,.5))
    plt.tight_layout()
    #save figure (key is a one-element list here, hence key[0])
    # fileName = 'fig3D_Area5'+'ExampleNeuronRaster'+key[0]+'.pdf'#Area5
    # plt.savefig(fileName,dpi=300)
#%% fig3D neural trajectory
disparitysEvent = np.array([45,-45,35,-35,20,-20,10,-10,0])
#Area5
pMeanFRSmoo = copy.deepcopy(pMeanFRSmooTrajArea5)
vpHandMeanFRSmoo = copy.deepcopy(vpHandMeanFRSmooTrajArea5)
vpcHandMeanFRSmoo = copy.deepcopy(vpcHandMeanFRSmooTrajArea5)
key = '20180505S1SPK01a'
# Time (s) of each of the 22 samples: -2.1 up to ~0.0 in 0.1 s steps.
timeBins = [i*.1-.8-1.3 for i in range(22)]
# The paper style is applied once from the working directory. A second,
# nested context pointing at a machine-specific absolute path would raise
# OSError everywhere else.
with plt.style.context('style_paper.mplstyle'):
    colorsDisp = plt.cm.jet(np.linspace(0,1,9))
    # Colour by rank of the disparity in ascending order (-45 -> first jet colour).
    dispColour = {d: colorsDisp[n] for n, d in enumerate(np.sort(disparitysEvent))}
    f, ax1 = plt.subplots(figsize=[3.54/3,3.54/2])
    ax1.plot(timeBins,pMeanFRSmoo,label = 'P',color = 'k',lw = 1,ls = '--')
    # Row `row` of vpcHandMeanFRSmoo is plotted with label disparitysEvent[row].
    for row, d in enumerate(disparitysEvent):
        ax1.plot(timeBins,
                 vpcHandMeanFRSmoo[row,:],
                 color = dispColour[d],
                 label = d,lw = 1)
    ax1.plot(timeBins,vpHandMeanFRSmoo,label = 'VP',color = 'gray',lw = 1,ls = '--')
    # Shaded window over the last 6 samples (indices 16..21).
    # NOTE(review): earlier inline comments tagged this 3.1 Hz ceiling as
    # "PMC" (and a 45 Hz variant as "Area5"), yet the data plotted here is
    # Area5 - confirm the intended height.
    holdWindow = timeBins[16:22]
    ax1.fill_between(holdWindow,[0] * len(holdWindow),[3.1] * len(holdWindow),
                     edgecolor=[], facecolor='gray',alpha=0.4)
    ax1.set_xlabel('Time (s)')
    ax1.set_ylabel('Firing rate (Hz)')
    ax1.set_xlim([-1,0.05])
    ax1.set_xticks(np.arange(-1,0.05,.5))
    plt.tight_layout()
    plt.ylim(bottom = 0)

    #save figure
    fileName = 'fig3D_exampleNeuronNeuralTrajectoryArea5'+key+'.pdf'
    # plt.savefig(fileName,dpi=300)
    plt.show()
#%%

#%% fig3D tuning curve
# Disparities in radians, in the same row order as vpcHandMeanFRSmoo.
disparitysEventRad = [math.radians(d) for d in disparitysEvent]
# Mean rate over the last 5 of the 22 samples, one value per disparity row.
holdRateVPC = np.nanmean(vpcHandMeanFRSmoo[:,range(22-5,22)],1)
parasVPC = getvonMisesParas(disparitysEventRad, holdRateVPC)

# Evaluate the fitted curve on a dense grid of angles (radians).
rangeDire = np.linspace(-0.88, 0.88, 40)
vpcHandFRMeanFitting = [vonMisesFunction(r, *parasVPC).tolist() for r in rangeDire]

rangeDireDeg = [math.degrees(r) for r in rangeDire]
flatX = np.linspace(-60,60,50)
with plt.style.context('style_paper.mplstyle'):
    plt.figure(figsize=[3.54/1.5,3.54/2])
    # P and VP references are flat lines at their mean rate over the same 5 samples.
    plt.plot(flatX,[np.nanmean(pMeanFRSmoo[22-5:22])] * len(flatX),color = 'k',lw = 1,ls = '--', label = 'P')
    plt.plot(flatX,[np.nanmean(vpHandMeanFRSmoo[22-5:22])] * len(flatX),color = 'gray',lw = 1,ls = '--',label = 'VP')

    plt.plot(disparitysEvent,holdRateVPC,'.',markersize = 5,color = '#ff7f0e',label = 'VPC')
    plt.plot(rangeDireDeg,vpcHandFRMeanFitting,color = '#ff7f0e',lw = 1)
    plt.legend(loc = 4,bbox_to_anchor=[1.1,0],labelspacing = 0.1,columnspacing = .05)
    plt.xlabel('Disparity (deg)')
    plt.ylabel('Firing rate (Hz)')

    #save figure (key is the neuron-ID string set in the trajectory cell)
    fileName = 'fig3D_tuningVPPVPCHoldingPeriodParetalLabel'+key+'.pdf'
    # plt.savefig(fileName,dpi=300)
#%%fig3E
# Simulated example: two rectified-cosine tuning curves (VP and P) and the
# Poisson probability of an observed VPC rate under each.
handPosition0 = np.linspace(-60,50,60)
handPosition = np.array([math.radians(deg) for deg in handPosition0])

# Tuning-curve parameters (b, a, u) for the VP and P curves.
bVP, aVP, uVP = 0.1, 10, -0.1
bP, aP, uP = 0.4, 7, 0.05

simuVPTuni = vonMisesFunction(handPosition,bVP,aVP,uVP)
simuPTuni = vonMisesFunction(handPosition,bP,aP,uP)

targPosi = 0
disp = 35#deg
drif = 25#deg
PHand = targPosi - drif
VPHand = PHand + disp
muVP = vonMisesFunction(math.radians(VPHand),bVP,aVP,uVP)
muP = vonMisesFunction(math.radians(PHand),bP,aP,uP)
VPCFiri = 9
firiRate = np.linspace(0,16,17)
VPProb = [getPossionPDF(muVP, rate) for rate in firiRate]
PProb = [getPossionPDF(muP, rate) for rate in firiRate]
VPProbInVPC = getPossionPDF(muVP,VPCFiri)
PProbInVPC = getPossionPDF(muP,VPCFiri)
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    plt.figure(figsize=(7.25/2.5,3/2))

    # Left panel: tuning curves, each with a dashed line at its predicted rate.
    plt.subplot(121)
    for tuning, hand, mu, tone in ((simuVPTuni, VPHand, muVP, colors[0]),
                                   (simuPTuni, PHand, muP, colors[1])):
        plt.plot(handPosition0,tuning,color = tone)
        plt.plot(np.linspace(-60,hand,50), [mu] * 50,'--',linewidth = 1,color = tone)
    plt.ylim([3,12])
    plt.xlim([-60,60])
    plt.plot(0,11,marker = 'o',color = 'r',alpha = 1,markersize = 5)
    prop = dict(arrowstyle="-|>,head_width=0.2,head_length=0.4",
                facecolor='gray',
                linewidth = 1,
                ec='gray',
                shrinkA=0,
                shrinkB=0)
    # Arrows from (0, 3) to the P estimate, the VP estimate and the target marker.
    for tip, tone in (((PHand,muP), colors[1]),
                      ((VPHand,muVP), colors[0]),
                      ((0,11), 'k')):
        plt.annotate("", xy=tip, xytext=(0,3),
                     arrowprops={**prop, 'facecolor': tone, 'ec': tone})
    plt.xlabel('Arm location (deg)')
    plt.ylabel('Firing rate (Hz)')

    # Right panels: Poisson PMF of each model, with the VPC rate marked.
    for panel, probs, probAtVPC, tone, title in ((222, VPProb, VPProbInVPC, colors[0], 'VP'),
                                                 (224, PProb, PProbInVPC, colors[1], 'P')):
        plt.subplot(panel)
        plt.plot(firiRate,probs,'-o',markersize = 1,color = tone)
        plt.plot(np.linspace(0,VPCFiri,20), [probAtVPC] * 20,'--',linewidth = 1,color = 'k')
        plt.plot([VPCFiri] * 20, np.linspace(0,probAtVPC,20),'--',linewidth = 1,color = 'k')
        if panel == 224:
            plt.xlabel('Firing rate (Hz)')
            plt.ylabel('Pr (X=k)')
        plt.xlim(left = 0)
        plt.ylim(bottom = 0)
        plt.title(title,color = tone)
    plt.tight_layout()
    fileName = 'fig3E_vpWeightCalcPlot.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig3F premotor
disp = np.array([-45,-35,-20,-10,0,10,20,35,45])
drif = np.array([45,35,20,10,0,-10,-20,-35,-45])
# pyplot.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9).
cm = plt.get_cmap('RdBu_r')
with plt.style.context('style_paper.mplstyle'):
    plt.figure(figsize=[3.54/1.5,3.54/2])
    weightPMC = vpWeigDrifDispPMC[:,0]
    # Colour limits form an asymmetric window around the median weight; they
    # are computed once instead of once per point.
    sc = plt.scatter(vpWeigDrifDispPMC[:,2]+np.random.normal(0,1,vpWeigDrifDispPMC.shape[0]),
                     vpWeigDrifDispPMC[:,1],
                     s = 5,
                     c = weightPMC,
                     vmin = np.median(weightPMC)-2*np.std(weightPMC),
                     vmax = np.median(weightPMC)+0.8*np.std(weightPMC),
                     cmap = cm)
    cb = plt.colorbar(sc)
    cb.outline.set_visible(False)
    plt.xlim([-55,55])
    plt.xlabel('Disparity (deg)')
    plt.ylabel('Drift (deg)')
    fileName = 'fig3F_examCaulInfeNeurPremotor.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig3F parietal
disp = np.array([-45,-35,-20,-10,0,10,20,35,45])
drif = np.array([45,35,20,10,0,-10,-20,-35,-45])
# pyplot.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9).
cm = plt.get_cmap('RdBu_r')
with plt.style.context('style_paper.mplstyle'):
    plt.figure(figsize=[3.54/1.5,3.54/2])
    # Single vectorised scatter (fixed colour limits 0-0.4), with a jitter of
    # sd 1 added to column 2.
    sc = plt.scatter(vpWeigDrifDispArea5[:,2]+np.random.normal(0,1,vpWeigDrifDispArea5.shape[0]),
                     vpWeigDrifDispArea5[:,1],
                     s = 5,
                     c = vpWeigDrifDispArea5[:,0],
                     vmin = 0,
                     vmax = .4,
                     cmap = cm)
    cb = plt.colorbar(sc)
    cb.outline.set_visible(False)
    plt.xlim([-55,55])
    plt.xlabel('Disparity (deg)')
    plt.ylabel('Drift (deg)')
    fileName = 'fig3F_examCaulInfeNeurParietal.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig3G
inputVPWeight = copy.deepcopy(np.nanmean(popuPMCVPWeight,0)) # Premotor
# inputVPWeight = copy.deepcopy(popuArea5VPWeight) # Parietal
disp = np.array([-45,-35,-20,-10,0,10,20,35,45])
drif = np.array([45,35,20,10,0,-10,-20,-35,-45])
# pyplot.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9).
cm = plt.get_cmap('RdBu_r')
with plt.style.context('style_paper.mplstyle'):
    plt.figure(figsize=[3.54/1.5,3.54/2])
    # Indices (i, j, k) of every non-NaN weight, in C order; one scatter call
    # for all of them, with sd-2 jitter on both axes.
    iIdx, jIdx, kIdx = np.nonzero(~np.isnan(inputVPWeight))
    nPts = iIdx.size
    # NOTE(review): x uses drif[j] but is labelled 'Disparity', and y uses
    # disp[i] but is labelled 'Drift' - confirm the axis semantics of
    # inputVPWeight.
    sc = plt.scatter(drif[jIdx]+np.random.normal(0,2,nPts),
                     disp[iIdx]+np.random.normal(0,2,nPts),
                     s = 1,
                     c = inputVPWeight[iIdx,jIdx,kIdx],
                     vmin = .45,
                     vmax = .55,
                     cmap = cm)
    cb = plt.colorbar(sc)
    cb.outline.set_visible(False)
    cb.ax.tick_params(labelsize=8)
    plt.xlabel('Disparity (deg)',fontsize = 8)
    plt.ylabel('Drift (deg)',fontsize = 8)
    plt.xticks(rotation=0,fontsize = 8)
    plt.yticks(rotation=0,fontsize = 8)
    fileName = 'fig3G_popuVPWeigPMC.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%
#3G
# Percentage of CI neurons per area, using hard-coded denominators.
# NOTE(review): PMCNeurNum / Area5NeurNum are loaded from fig3Data but not
# used here - confirm they equal 412 / 238.
causInfeNeurFrac = [count/total*100 for count, total in ((caulInfeNeurPMCNum, 412),
                                                         (caulInfeNeurArea5Num, 238))]
with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"][3:]
    x_labels = ['Premotor\n(n=412)','Parietal\n(n=238)']
    width = 0.3
    x = np.arange(2)
    fig, ax = plt.subplots(figsize=[3.54/1.5,3.54/2])
    for pos, frac in zip(x, causInfeNeurFrac):
        ax.bar(pos, [frac], width, color = 'gray')
    ax.set_ylabel('CI neuron (%)',fontsize = 8)
    plt.xticks(rotation=0,fontsize = 8)
    plt.yticks(rotation=0,fontsize = 8)
    ax.set_xlim([-.5,1.5])
    ax.set_xticks(x)
    ax.set_xticklabels(x_labels[:2])
    fig.tight_layout()
    fileName = 'fig3G_causInfeNeurFrac.pdf'
    plt.savefig(fileName,dpi = 600)
plt.show()

#%%fig3H
# Keep only samples where BOTH variables are non-NaN; a NaN in either one
# makes LinearRegression.fit raise and linregress return NaN.
index = ~np.isnan(vpWeighExamParietal) & ~np.isnan(pCommonExamParietal)
# NOTE(review): regr / diabetes_y_pred are not used below; kept as globals.
regr = linear_model.LinearRegression()
regr.fit(pCommonExamParietal[index].reshape(-1, 1), vpWeighExamParietal[index].reshape(-1, 1))
diabetes_y_pred = regr.predict(pCommonExamParietal[index].reshape(-1, 1))
slope, intercept, r_value, p_value, std_err = stats.linregress(pCommonExamParietal[index], vpWeighExamParietal[index])
print(r_value, p_value)
with plt.style.context('style_paper.mplstyle'):
    f, ax1= plt.subplots(ncols=1, nrows=1, sharex=True,figsize=[7.25/2.5/1.5,3/2])
    # x/y passed as keywords: seaborn >= 0.12 rejects positional data args.
    # '+' in the intercept format keeps the sign visible, e.g. "y=0.5x+0.3".
    sns.regplot(x=pCommonExamParietal[index], y=vpWeighExamParietal[index],
                color ='black',
                scatter_kws={"s": 10, 'edgecolors':[]},
                line_kws={'label':"y={0:.1f}x{1:+.1f}".format(slope,intercept)})
    plt.xlabel('$P_{com}$')
    plt.ylabel('VP Weight')
    plt.tight_layout()
    plt.xticks(np.linspace(0,1,3))
    plt.xlim([0.18,1.04])
    fileName = 'fig3H_examVpWeigPCommCorrParietal.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
